{
    "cells": [
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "<i>Copyright (c) Recommenders contributors.</i>\n",
                "\n",
                "<i>Licensed under the MIT License.</i>"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Content-Based Personalization with LightGBM on Spark\n",
                "\n",
                "This notebook provides a quick example of how to train a [LightGBM](https://github.com/Microsoft/Lightgbm) model on Spark using [MMLSpark](https://github.com/Azure/mmlspark) for a content-based personalization scenario.\n",
                "\n",
                "We use the [CRITEO dataset](https://www.kaggle.com/c/criteo-display-ad-challenge), a well-known dataset of website ads that can be used to optimize the Click-Through Rate (CTR). The dataset consists of a series of numerical and categorical features and a binary label indicating whether the ad has been clicked or not.\n",
                "\n",
                "The model is based on [LightGBM](https://github.com/Microsoft/Lightgbm), which is a gradient boosting framework that uses tree-based learning algorithms. Finally, we take advantage of\n",
                "the [MMLSpark](https://github.com/Azure/mmlspark) library, which allows LightGBM to be called in a Spark environment and run in a distributed fashion.\n",
                "\n",
                "This scenario is a good example of **implicit feedback**, where binary labels indicate the interaction between a user and an item. This contrasts with explicit feedback, where the user explicitly rates the content, for example from 1 to 5.\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Global Settings and Imports"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "This notebook can be run in a Spark environment in a DSVM or in Azure Databricks. For more details about the installation process, please refer to the [setup instructions](../../SETUP.md).\n",
                "\n",
                "**NOTE for Azure Databricks:**\n",
                "* A Python script is provided to simplify setting up Azure Databricks with the correct dependencies. Run ```python tools/databricks_install.py -h``` for more details.\n",
                "* MMLSpark should not be run on a cluster with autoscaling enabled. Disable the flag in the Azure Databricks Cluster configuration before running this notebook."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 1,
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "MMLSpark version: com.microsoft.azure:synapseml_2.12:0.9.5\n",
                        "System version: 3.8.0 (default, Nov  6 2019, 21:49:08) \n",
                        "[GCC 7.3.0]\n",
                        "PySpark version: 3.2.0\n"
                    ]
                }
            ],
            "source": [
                "import warnings\n",
                "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
                "\n",
                "import os\n",
                "import sys\n",
                "\n",
                "import pyspark\n",
                "from pyspark.ml import PipelineModel\n",
                "from pyspark.ml.feature import FeatureHasher\n",
                "\n",
                "from recommenders.utils.notebook_utils import is_databricks, store_metadata\n",
                "from recommenders.utils.spark_utils import start_or_get_spark\n",
                "from recommenders.datasets.criteo import load_spark_df\n",
                "from recommenders.datasets.spark_splitters import spark_random_split\n",
                "from recommenders.utils.spark_utils import MMLSPARK_REPO, MMLSPARK_PACKAGE\n",
                "\n",
                "# On Spark >3.0.0,<3.2.0, the following should be set:\n",
                "#     MMLSPARK_PACKAGE = \"com.microsoft.azure:synapseml_2.12:0.9.4\"\n",
                "packages = [MMLSPARK_PACKAGE]\n",
                "repos = [MMLSPARK_REPO]\n",
                "spark = start_or_get_spark(packages=packages, repositories=repos)\n",
                "dbutils = None\n",
                "print(\"MMLSpark version: {}\".format(MMLSPARK_PACKAGE))\n",
                "\n",
                "from synapse.ml.train import ComputeModelStatistics\n",
                "from synapse.ml.lightgbm import LightGBMClassifier\n",
                "\n",
                "print(\"System version: {}\".format(sys.version))\n",
                "print(\"PySpark version: {}\".format(pyspark.version.__version__))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 2,
            "metadata": {
                "tags": [
                    "parameters"
                ]
            },
            "outputs": [],
            "source": [
                "# Criteo data size, it can be \"sample\" or \"full\"\n",
                "DATA_SIZE = \"sample\"\n",
                "\n",
                "# LightGBM parameters\n",
                "# More details on parameters: https://lightgbm.readthedocs.io/en/latest/Parameters-Tuning.html\n",
                "NUM_LEAVES = 32\n",
                "NUM_ITERATIONS = 50\n",
                "LEARNING_RATE = 0.1\n",
                "FEATURE_FRACTION = 0.8\n",
                "\n",
                "# Model name\n",
                "MODEL_NAME = 'lightgbm_criteo.mml'"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Data Preparation\n",
                "\n",
                "The [Criteo Display Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge) (Criteo DAC) dataset is a well-known industry benchmarking dataset for developing CTR prediction models, and is frequently used in research papers. The original dataset contains over 45M rows, but there is also a down-sampled dataset which has 100,000 rows (this can be used by setting `DATA_SIZE = \"sample\"`). Each row corresponds to a display ad served by Criteo, and the first column indicates whether this ad has been clicked or not.<br><br>\n",
                "The dataset contains 1 label column and 39 feature columns, where 13 columns are integer values (int00-int12) and 26 columns are categorical features (cat00-cat25).<br><br>\n",
                "What the columns represent is not provided, but for this case we can consider the integer and categorical values as features representing the user and / or item content. The label is binary and is an example of implicit feedback indicating a user's interaction with an item. With this dataset we can demonstrate how to build a model that predicts the probability of a user interacting with an item based on available user and item content features.\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 3,
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8.58k/8.58k [00:06<00:00, 1.24kKB/s]\n",
                        "                                                                                \r"
                    ]
                },
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>label</th>\n",
                            "      <th>int00</th>\n",
                            "      <th>int01</th>\n",
                            "      <th>int02</th>\n",
                            "      <th>int03</th>\n",
                            "      <th>int04</th>\n",
                            "      <th>int05</th>\n",
                            "      <th>int06</th>\n",
                            "      <th>int07</th>\n",
                            "      <th>int08</th>\n",
                            "      <th>...</th>\n",
                            "      <th>cat16</th>\n",
                            "      <th>cat17</th>\n",
                            "      <th>cat18</th>\n",
                            "      <th>cat19</th>\n",
                            "      <th>cat20</th>\n",
                            "      <th>cat21</th>\n",
                            "      <th>cat22</th>\n",
                            "      <th>cat23</th>\n",
                            "      <th>cat24</th>\n",
                            "      <th>cat25</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>0</th>\n",
                            "      <td>0</td>\n",
                            "      <td>1</td>\n",
                            "      <td>1</td>\n",
                            "      <td>5</td>\n",
                            "      <td>0</td>\n",
                            "      <td>1382</td>\n",
                            "      <td>4</td>\n",
                            "      <td>15</td>\n",
                            "      <td>2</td>\n",
                            "      <td>181</td>\n",
                            "      <td>...</td>\n",
                            "      <td>e5ba7672</td>\n",
                            "      <td>f54016b9</td>\n",
                            "      <td>21ddcdc9</td>\n",
                            "      <td>b1252a9d</td>\n",
                            "      <td>07b5194c</td>\n",
                            "      <td>None</td>\n",
                            "      <td>3a171ecb</td>\n",
                            "      <td>c5c50484</td>\n",
                            "      <td>e8b83407</td>\n",
                            "      <td>9727dd16</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>1</th>\n",
                            "      <td>0</td>\n",
                            "      <td>2</td>\n",
                            "      <td>0</td>\n",
                            "      <td>44</td>\n",
                            "      <td>1</td>\n",
                            "      <td>102</td>\n",
                            "      <td>8</td>\n",
                            "      <td>2</td>\n",
                            "      <td>2</td>\n",
                            "      <td>4</td>\n",
                            "      <td>...</td>\n",
                            "      <td>07c540c4</td>\n",
                            "      <td>b04e4670</td>\n",
                            "      <td>21ddcdc9</td>\n",
                            "      <td>5840adea</td>\n",
                            "      <td>60f6221e</td>\n",
                            "      <td>None</td>\n",
                            "      <td>3a171ecb</td>\n",
                            "      <td>43f13e8b</td>\n",
                            "      <td>e8b83407</td>\n",
                            "      <td>731c3655</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "<p>2 rows × 40 columns</p>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "   label  int00  int01  int02  int03  int04  int05  int06  int07  int08  ...  \\\n",
                            "0      0      1      1      5      0   1382      4     15      2    181  ...   \n",
                            "1      0      2      0     44      1    102      8      2      2      4  ...   \n",
                            "\n",
                            "      cat16     cat17     cat18     cat19     cat20 cat21     cat22     cat23  \\\n",
                            "0  e5ba7672  f54016b9  21ddcdc9  b1252a9d  07b5194c  None  3a171ecb  c5c50484   \n",
                            "1  07c540c4  b04e4670  21ddcdc9  5840adea  60f6221e  None  3a171ecb  43f13e8b   \n",
                            "\n",
                            "      cat24     cat25  \n",
                            "0  e8b83407  9727dd16  \n",
                            "1  e8b83407  731c3655  \n",
                            "\n",
                            "[2 rows x 40 columns]"
                        ]
                    },
                    "execution_count": 3,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "raw_data = load_spark_df(size=DATA_SIZE, spark=spark, dbutils=dbutils)\n",
                "# visualize data\n",
                "raw_data.limit(2).toPandas().head()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Feature Processing\n",
                "The feature data provided has many missing values across both integer and categorical feature fields. In addition, the categorical features have many distinct values, so effectively cleaning and representing the feature data is an important step prior to training a model.<br><br>\n",
                "One of the simplest ways of handling features that have both missing values and high cardinality is to use the hashing trick. The [FeatureHasher](http://spark.apache.org/docs/latest/ml-features.html#featurehasher) transformer will pass integer values through and will hash categorical features into a sparse vector of lower dimensionality, which can be used effectively by LightGBM.<br><br>\n",
                "First, the dataset is split randomly into training and test sets, and then feature processing is applied to each set."
            ]
        },
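        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "To make the hashing trick more concrete, here is a minimal pure-Python sketch; it is an illustration, not Spark's actual implementation. It assumes MD5 as a stand-in hash (Spark's `FeatureHasher` uses MurmurHash3) and a tiny, illustrative bucket count, and `bucket` and `hash_row` are hypothetical helper names. Numeric values go into the bucket chosen by their column name, categorical values add 1.0 to the bucket chosen by `column=value`, and missing values are skipped.\n",
                "\n",
                "```python\n",
                "import hashlib\n",
                "\n",
                "NUM_BUCKETS = 16  # illustrative only; Spark's FeatureHasher defaults to 2**18\n",
                "\n",
                "\n",
                "def bucket(key, num_buckets=NUM_BUCKETS):\n",
                "    # Stable stand-in hash (MD5); an assumption made for this sketch.\n",
                "    digest = hashlib.md5(key.encode('utf-8')).hexdigest()\n",
                "    return int(digest, 16) % num_buckets\n",
                "\n",
                "\n",
                "def hash_row(row, num_buckets=NUM_BUCKETS):\n",
                "    # Build a sparse vector represented as {bucket_index: value}.\n",
                "    vec = {}\n",
                "    for col, value in row.items():\n",
                "        if value is None:\n",
                "            continue  # missing values contribute nothing\n",
                "        if isinstance(value, (int, float)):\n",
                "            # Numeric feature: hash the column name, keep the value.\n",
                "            idx, val = bucket(col, num_buckets), float(value)\n",
                "        else:\n",
                "            # Categorical feature: hash 'column=value', use an indicator of 1.0.\n",
                "            idx, val = bucket(col + '=' + str(value), num_buckets), 1.0\n",
                "        vec[idx] = vec.get(idx, 0.0) + val  # colliding features add up\n",
                "    return vec\n",
                "\n",
                "\n",
                "row = {'int00': 1, 'int04': 1382, 'cat16': 'e5ba7672', 'cat21': None}\n",
                "print(hash_row(row))\n",
                "```\n",
                "\n",
                "Features that collide simply add up in the same bucket, which is the price paid for a fixed-size vector that needs no vocabulary of category values."
            ]
        },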
        {
            "cell_type": "code",
            "execution_count": 4,
            "metadata": {},
            "outputs": [],
            "source": [
                "raw_train, raw_test = spark_random_split(raw_data, ratio=0.8, seed=42)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 5,
            "metadata": {},
            "outputs": [],
            "source": [
                "columns = [c for c in raw_data.columns if c != 'label']\n",
                "feature_processor = FeatureHasher(inputCols=columns, outputCol='features')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 6,
            "metadata": {},
            "outputs": [],
            "source": [
                "train = feature_processor.transform(raw_train)\n",
                "test = feature_processor.transform(raw_test)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Model Training\n",
                "In MMLSpark, the LightGBM implementation for binary classification is invoked using the `LightGBMClassifier` class and specifying the objective as `\"binary\"`. In this instance, the occurrence of positive labels is quite low, so setting the `isUnbalance` flag to true helps account for this imbalance.<br><br>\n",
                "\n",
                "### Hyper-parameters\n",
                "Below are some of the key [hyper-parameters](https://github.com/Microsoft/LightGBM/blob/master/docs/Parameters-Tuning.rst) for training a LightGBM classifier on Spark:\n",
                "- `numLeaves`: the number of leaves in each tree\n",
                "- `numIterations`: the number of iterations to apply boosting\n",
                "- `learningRate`: the learning rate for training across trees\n",
                "- `featureFraction`: the fraction of features used for training a tree"
            ]
        },
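        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "One way to picture what an unbalance flag does is to up-weight the rarer class by the ratio of the class counts. The sketch below illustrates that idea only, under the assumption that the minority class is weighted by `majority_count / minority_count`; `unbalance_weights` is a hypothetical helper, not part of LightGBM or MMLSpark, and the library's internal handling may differ.\n",
                "\n",
                "```python\n",
                "def unbalance_weights(labels):\n",
                "    # Return (weight_for_negatives, weight_for_positives) for binary labels.\n",
                "    pos = sum(1 for y in labels if y == 1)\n",
                "    neg = len(labels) - pos\n",
                "    if pos == 0 or neg == 0:\n",
                "        return 1.0, 1.0  # only one class present, nothing to rebalance\n",
                "    if pos > neg:\n",
                "        return pos / neg, 1.0  # negatives are rarer, so up-weight them\n",
                "    return 1.0, neg / pos  # positives are rarer, so up-weight them\n",
                "\n",
                "\n",
                "# Toy labels with 25% positives, chosen only for illustration.\n",
                "labels = [1] * 25 + [0] * 75\n",
                "print(unbalance_weights(labels))\n",
                "```\n",
                "\n",
                "With such weights, each rare positive example counts as much in the loss as several negatives, so the model is not rewarded for predicting the majority class everywhere."
            ]
        },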
        {
            "cell_type": "code",
            "execution_count": 7,
            "metadata": {},
            "outputs": [],
            "source": [
                "lgbm = LightGBMClassifier(\n",
                "    labelCol=\"label\",\n",
                "    featuresCol=\"features\",\n",
                "    objective=\"binary\",\n",
                "    isUnbalance=True,\n",
                "    boostingType=\"gbdt\",\n",
                "    boostFromAverage=True,\n",
                "    baggingSeed=42,\n",
                "    numLeaves=NUM_LEAVES,\n",
                "    numIterations=NUM_ITERATIONS,\n",
                "    learningRate=LEARNING_RATE,\n",
                "    featureFraction=FEATURE_FRACTION\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Model Training and Evaluation"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 8,
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "                                                                                \r"
                    ]
                }
            ],
            "source": [
                "model = lgbm.fit(train)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 9,
            "metadata": {},
            "outputs": [],
            "source": [
                "predictions = model.transform(test)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 10,
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "                                                                                \r"
                    ]
                },
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "+---------------+------------------+\n",
                        "|evaluation_type|               AUC|\n",
                        "+---------------+------------------+\n",
                        "| Classification|0.6590485347443004|\n",
                        "+---------------+------------------+\n",
                        "\n"
                    ]
                }
            ],
            "source": [
                "evaluator = (\n",
                "    ComputeModelStatistics()\n",
                "    .setScoredLabelsCol(\"prediction\")\n",
                "    .setLabelCol(\"label\")\n",
                "    .setEvaluationMetric(\"AUC\")\n",
                ")\n",
                "\n",
                "result = evaluator.transform(predictions)\n",
                "auc = result.select(\"AUC\").collect()[0][0]\n",
                "result.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 11,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "application/scrapbook.scrap.json+json": {
                            "data": 0.6590485347443004,
                            "encoder": "json",
                            "name": "auc",
                            "version": 1
                        }
                    },
                    "metadata": {
                        "scrapbook": {
                            "data": true,
                            "display": false,
                            "name": "auc"
                        }
                    },
                    "output_type": "display_data"
                }
            ],
            "source": [
                "# Record results for tests - ignore this cell\n",
                "store_metadata(\"auc\", auc)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Model Saving\n",
                "The full pipeline for operating on raw data, including feature processing and model prediction, can be saved and reloaded for use in another workflow."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 12,
            "metadata": {},
            "outputs": [],
            "source": [
                "# save model\n",
                "pipeline = PipelineModel(stages=[feature_processor, model])\n",
                "pipeline.write().overwrite().save(MODEL_NAME)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 13,
            "metadata": {},
            "outputs": [],
            "source": [
                "# cleanup spark instance\n",
                "if not is_databricks():\n",
                "    spark.stop()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Additional Reading\n",
                "\\[1\\] Guolin Ke, Qi Meng, Thomas Finley, Taifeng Wang, Wei Chen, Weidong Ma, Qiwei Ye, and Tie-Yan Liu. 2017. LightGBM: A highly efficient gradient boosting decision tree. In Advances in Neural Information Processing Systems. 3146–3154. https://papers.nips.cc/paper/6907-lightgbm-a-highly-efficient-gradient-boosting-decision-tree.pdf <br>\n",
                "\\[2\\] MML Spark: https://mmlspark.blob.core.windows.net/website/index.html <br>\n"
            ]
        }
    ],
    "metadata": {
        "celltoolbar": "Tags",
        "kernelspec": {
            "display_name": "Python (reco)",
            "language": "python",
            "name": "reco"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.8.0"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
